{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reference: [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import torchtext\n",
    "from torchtext.datasets import TranslationDataset, Multi30k\n",
    "from torchtext.data import Field, BucketIterator\n",
    "\n",
    "import spacy\n",
    "\n",
    "import random\n",
    "import math\n",
    "import os\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 1\n",
    "\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the spaCy pipelines for German and English; only their tokenizers are used below.\n",
    "spacy_de, spacy_en = (spacy.load(lang) for lang in ('de', 'en'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_de(text):\n",
    "    \"\"\"\n",
    "    Split a German string into a list of token strings with the spaCy tokenizer.\n",
    "    \"\"\"\n",
    "    doc = spacy_de.tokenizer(text)\n",
    "    return [token.text for token in doc]\n",
    "\n",
    "def tokenize_en(text):\n",
    "    \"\"\"\n",
    "    Split an English string into a list of token strings with the spaCy tokenizer.\n",
    "    \"\"\"\n",
    "    doc = spacy_en.tokenizer(text)\n",
    "    return [token.text for token in doc]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Source and target fields share the same preprocessing; only the tokenizer differs.\n",
    "field_kwargs = dict(init_token='<sos>', eos_token='<eos>', lower=True, batch_first=True)\n",
    "\n",
    "SRC = Field(tokenize=tokenize_de, **field_kwargs)\n",
    "TRG = Field(tokenize=tokenize_en, **field_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Multi30k train / validation / test splits; exts and fields are given in matching (German, English) order.\n",
    "train_data, valid_data, test_data = Multi30k.splits(exts=('.de', '.en'), fields=(SRC, TRG))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build both vocabularies from the training split only, with a minimum token frequency of 2.\n",
    "for field in (SRC, TRG):\n",
    "    field.build_vocab(train_data, min_freq=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 128\n",
    "\n",
    "# One iterator per split, all with the same batch size and placed on `device`.\n",
    "splits = (train_data, valid_data, test_data)\n",
    "train_iterator, valid_iterator, test_iterator = BucketIterator.splits(\n",
    "    splits, batch_size=BATCH_SIZE, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, encoder_layer, self_attention, positionwise_feedforward, dropout, device):\n",
    "        super().__init__()\n",
    "\n",
    "        self.input_dim = input_dim\n",
    "        self.hid_dim = hid_dim\n",
    "        self.n_layers = n_layers\n",
    "        self.n_heads = n_heads\n",
    "        self.pf_dim = pf_dim\n",
    "        self.encoder_layer = encoder_layer\n",
    "        self.self_attention = self_attention\n",
    "        self.positionwise_feedforward = positionwise_feedforward\n",
    "        self.dropout = dropout\n",
    "        self.device = device\n",
    "        \n",
    "        self.tok_embedding = nn.Embedding(input_dim, hid_dim)\n",
    "        self.pos_embedding = nn.Embedding(1000, hid_dim)\n",
    "        \n",
    "        self.layers = nn.ModuleList([encoder_layer(hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout, device) \n",
    "                                     for _ in range(n_layers)])\n",
    "        \n",
    "        self.do = nn.Dropout(dropout)\n",
    "        \n",
    "        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)\n",
    "        \n",
    "    def forward(self, src, src_mask):\n",
    "        \n",
    "        #src = [batch size, src sent len]\n",
    "        #src_mask = [batch size, src sent len]\n",
    "        \n",
    "        pos = torch.arange(0, src.shape[1]).unsqueeze(0).repeat(src.shape[0], 1).to(self.device)\n",
    "        \n",
    "        src = self.do((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))\n",
    "        \n",
    "        #src = [batch size, src sent len, hid dim]\n",
    "        \n",
    "        for layer in self.layers:\n",
    "            src = layer(src, src_mask)\n",
    "            \n",
    "        return src"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EncoderLayer(nn.Module):\n",
    "    def __init__(self, hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout, device):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.ln = nn.LayerNorm(hid_dim)\n",
    "        self.sa = self_attention(hid_dim, n_heads, dropout, device)\n",
    "        self.pf = positionwise_feedforward(hid_dim, pf_dim, dropout)\n",
    "        self.do = nn.Dropout(dropout)\n",
    "        \n",
    "    def forward(self, src, src_mask):\n",
    "        \n",
    "        #src = [batch size, src sent len, hid dim]\n",
    "        #src_mask = [batch size, src sent len]\n",
    "        \n",
    "        src = self.ln(src + self.do(self.sa(src, src, src, src_mask)))\n",
    "        \n",
    "        src = self.ln(src + self.do(self.pf(src)))\n",
    "        \n",
    "        return src"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SelfAttention(nn.Module):\n",
    "    def __init__(self, hid_dim, n_heads, dropout, device):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.hid_dim = hid_dim\n",
    "        self.n_heads = n_heads\n",
    "        \n",
    "        assert hid_dim % n_heads == 0\n",
    "        \n",
    "        self.w_q = nn.Linear(hid_dim, hid_dim)\n",
    "        self.w_k = nn.Linear(hid_dim, hid_dim)\n",
    "        self.w_v = nn.Linear(hid_dim, hid_dim)\n",
    "        \n",
    "        self.fc = nn.Linear(hid_dim, hid_dim)\n",
    "        \n",
    "        self.do = nn.Dropout(dropout)\n",
    "        \n",
    "        self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads])).to(device)\n",
    "        \n",
    "    def forward(self, query, key, value, mask=None):\n",
    "        \n",
    "        bsz = query.shape[0]\n",
    "        \n",
    "        #query = key = value [batch size, sent len, hid dim]\n",
    "                \n",
    "        Q = self.w_q(query)\n",
    "        K = self.w_k(key)\n",
    "        V = self.w_v(value)\n",
    "        \n",
    "        #Q, K, V = [batch size, sent len, hid dim]\n",
    "        \n",
    "        Q = Q.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)\n",
    "        K = K.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)\n",
    "        V = V.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)\n",
    "        \n",
    "        #Q, K, V = [batch size, n heads, sent len, hid dim // n heads]\n",
    "        \n",
    "        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale\n",
    "        \n",
    "        #energy = [batch size, n heads, sent len, sent len]\n",
    "        \n",
    "        if mask is not None:\n",
    "            energy = energy.masked_fill(mask == 0, -1e10)\n",
    "        \n",
    "        attention = self.do(F.softmax(energy, dim=-1))\n",
    "        \n",
    "        #attention = [batch size, n heads, sent len, sent len]\n",
    "        \n",
    "        x = torch.matmul(attention, V)\n",
    "        \n",
    "        #x = [batch size, n heads, sent len, hid dim // n heads]\n",
    "        \n",
    "        x = x.permute(0, 2, 1, 3).contiguous()\n",
    "        \n",
    "        #x = [batch size, sent len, n heads, hid dim // n heads]\n",
    "        \n",
    "        x = x.view(bsz, -1, self.n_heads * (self.hid_dim // self.n_heads))\n",
    "        \n",
    "        #x = [batch size, src sent len, hid dim]\n",
    "        \n",
    "        x = self.fc(x)\n",
    "        \n",
    "        #x = [batch size, sent len, hid dim]\n",
    "        \n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PositionwiseFeedforward(nn.Module):\n",
    "    def __init__(self, hid_dim, pf_dim, dropout):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.hid_dim = hid_dim\n",
    "        self.pf_dim = pf_dim\n",
    "        \n",
    "        self.fc_1 = nn.Conv1d(hid_dim, pf_dim, 1)\n",
    "        self.fc_2 = nn.Conv1d(pf_dim, hid_dim, 1)\n",
    "        \n",
    "        self.do = nn.Dropout(dropout)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        \n",
    "        #x = [batch size, sent len, hid dim]\n",
    "        \n",
    "        x = x.permute(0, 2, 1)\n",
    "        \n",
    "        #x = [batch size, hid dim, sent len]\n",
    "        \n",
    "        x = self.do(F.relu(self.fc_1(x)))\n",
    "        \n",
    "        #x = [batch size, ff dim, sent len]\n",
    "        \n",
    "        x = self.fc_2(x)\n",
    "        \n",
    "        #x = [batch size, hid dim, sent len]\n",
    "        \n",
    "        x = x.permute(0, 2, 1)\n",
    "        \n",
    "        #x = [batch size, sent len, hid dim]\n",
    "        \n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(nn.Module):\n",
    "    def __init__(self, output_dim, hid_dim, n_layers, n_heads, pf_dim, decoder_layer, self_attention, positionwise_feedforward, dropout, device):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.output_dim = output_dim\n",
    "        self.hid_dim = hid_dim\n",
    "        self.n_layers = n_layers\n",
    "        self.n_heads = n_heads\n",
    "        self.pf_dim = pf_dim\n",
    "        self.decoder_layer = decoder_layer\n",
    "        self.self_attention = self_attention\n",
    "        self.positionwise_feedforward = positionwise_feedforward\n",
    "        self.dropout = dropout\n",
    "        self.device = device\n",
    "        \n",
    "        self.tok_embedding = nn.Embedding(output_dim, hid_dim)\n",
    "        self.pos_embedding = nn.Embedding(1000, hid_dim)\n",
    "        \n",
    "        self.layers = nn.ModuleList([decoder_layer(hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout, device)\n",
    "                                     for _ in range(n_layers)])\n",
    "        \n",
    "        self.fc = nn.Linear(hid_dim, output_dim)\n",
    "        \n",
    "        self.do = nn.Dropout(dropout)\n",
    "        \n",
    "        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)\n",
    "        \n",
    "    def forward(self, trg, src, trg_mask, src_mask):\n",
    "        \n",
    "        #trg = [batch_size, trg sent len]\n",
    "        #src = [batch_size, src sent len]\n",
    "        #trg_mask = [batch size, trg sent len]\n",
    "        #src_mask = [batch size, src sent len]\n",
    "        \n",
    "        pos = torch.arange(0, trg.shape[1]).unsqueeze(0).repeat(trg.shape[0], 1).to(self.device)\n",
    "                \n",
    "        trg = self.do((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))\n",
    "        \n",
    "        #trg = [batch size, trg sent len, hid dim]\n",
    "        \n",
    "        for layer in self.layers:\n",
    "            trg = layer(trg, src, trg_mask, src_mask)\n",
    "            \n",
    "        return self.fc(trg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DecoderLayer(nn.Module):\n",
    "    def __init__(self, hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout, device):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.ln = nn.LayerNorm(hid_dim)\n",
    "        self.sa = self_attention(hid_dim, n_heads, dropout, device)\n",
    "        self.ea = self_attention(hid_dim, n_heads, dropout, device)\n",
    "        self.pf = positionwise_feedforward(hid_dim, pf_dim, dropout)\n",
    "        self.do = nn.Dropout(dropout)\n",
    "        \n",
    "    def forward(self, trg, src, trg_mask, src_mask):\n",
    "        \n",
    "        #trg = [batch size, trg sent len, hid dim]\n",
    "        #src = [batch size, src sent len, hid dim]\n",
    "        #trg_mask = [batch size, trg sent len]\n",
    "        #src_mask = [batch size, src sent len]\n",
    "                \n",
    "        trg = self.ln(trg + self.do(self.sa(trg, trg, trg, trg_mask)))\n",
    "                \n",
    "        trg = self.ln(trg + self.do(self.ea(trg, src, src, src_mask)))\n",
    "        \n",
    "        trg = self.ln(trg + self.do(self.pf(trg)))\n",
    "        \n",
    "        return trg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Seq2Seq(nn.Module):\n",
    "    def __init__(self, encoder, decoder, pad_idx, device):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "        self.pad_idx = pad_idx\n",
    "        self.device = device\n",
    "        \n",
    "    def make_masks(self, src, trg):\n",
    "        \n",
    "        #src = [batch size, src sent len]\n",
    "        #trg = [batch size, trg sent len]\n",
    "        \n",
    "        src_mask = (src != self.pad_idx).unsqueeze(1).unsqueeze(2)\n",
    "        \n",
    "        trg_pad_mask = (trg != self.pad_idx).unsqueeze(1).unsqueeze(3)\n",
    "        \n",
    "        #src_mask = [batch size, 1, 1, src sent len]\n",
    "        #trg_pad_mask = [batch size, 1, trg sent len, 1]\n",
    "        \n",
    "        trg_len = trg.shape[1]\n",
    "        \n",
    "        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=self.device)).bool()\n",
    "                        \n",
    "        #trg_sub_mask = [trg sent len, trg sent len]\n",
    "            \n",
    "        trg_mask = trg_pad_mask & trg_sub_mask\n",
    "        \n",
    "        #trg_mask = [batch size, 1, trg sent len, trg sent len]\n",
    "        \n",
    "        return src_mask, trg_mask\n",
    "    \n",
    "    def forward(self, src, trg):\n",
    "        \n",
    "        #src = [batch size, src sent len]\n",
    "        #trg = [batch size, trg sent len]\n",
    "                \n",
    "        src_mask, trg_mask = self.make_masks(src, trg)\n",
    "        \n",
    "        enc_src = self.encoder(src, src_mask)\n",
    "        \n",
    "        #enc_src = [batch size, src sent len, hid dim]\n",
    "                \n",
    "        out = self.decoder(trg, enc_src, trg_mask, src_mask)\n",
    "        \n",
    "        #out = [batch size, trg sent len, output dim]\n",
    "        \n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder hyperparameters; the vocabulary size comes from the source field.\n",
    "input_dim = len(SRC.vocab)\n",
    "hid_dim, n_layers, n_heads = 512, 6, 8\n",
    "pf_dim, dropout = 2048, 0.1\n",
    "\n",
    "enc = Encoder(\n",
    "    input_dim=input_dim,\n",
    "    hid_dim=hid_dim,\n",
    "    n_layers=n_layers,\n",
    "    n_heads=n_heads,\n",
    "    pf_dim=pf_dim,\n",
    "    encoder_layer=EncoderLayer,\n",
    "    self_attention=SelfAttention,\n",
    "    positionwise_feedforward=PositionwiseFeedforward,\n",
    "    dropout=dropout,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decoder hyperparameters; the vocabulary size comes from the target field.\n",
    "output_dim = len(TRG.vocab)\n",
    "hid_dim, n_layers, n_heads = 512, 6, 8\n",
    "pf_dim, dropout = 2048, 0.1\n",
    "\n",
    "dec = Decoder(\n",
    "    output_dim=output_dim,\n",
    "    hid_dim=hid_dim,\n",
    "    n_layers=n_layers,\n",
    "    n_heads=n_heads,\n",
    "    pf_dim=pf_dim,\n",
    "    decoder_layer=DecoderLayer,\n",
    "    self_attention=SelfAttention,\n",
    "    positionwise_feedforward=PositionwiseFeedforward,\n",
    "    dropout=dropout,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Padding index, looked up in the source vocabulary; Seq2Seq uses it to mask both source and target.\n",
    "pad_idx = SRC.vocab.stoi['<pad>']\n",
    "\n",
    "model = Seq2Seq(enc, dec, pad_idx, device)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The model has 55,206,149 trainable parameters\n"
     ]
    }
   ],
   "source": [
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "# Report the trainable-parameter count with thousands separators.\n",
    "print(f'The model has {count_parameters(model):,} trainable parameters')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-initialise every parameter with more than one dimension using Xavier uniform;\n",
    "# one-dimensional parameters keep their default initialisation.\n",
    "weight_tensors = [p for p in model.parameters() if p.dim() > 1]\n",
    "for weight in weight_tensors:\n",
    "    nn.init.xavier_uniform_(weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NoamOpt:\n",
    "    \"Optim wrapper that implements rate.\"\n",
    "    def __init__(self, model_size, factor, warmup, optimizer):\n",
    "        self.optimizer = optimizer\n",
    "        self._step = 0\n",
    "        self.warmup = warmup\n",
    "        self.factor = factor\n",
    "        self.model_size = model_size\n",
    "        self._rate = 0\n",
    "        \n",
    "    def step(self):\n",
    "        \"Update parameters and rate\"\n",
    "        self._step += 1\n",
    "        rate = self.rate()\n",
    "        for p in self.optimizer.param_groups:\n",
    "            p['lr'] = rate\n",
    "        self._rate = rate\n",
    "        self.optimizer.step()\n",
    "        \n",
    "    def rate(self, step = None):\n",
    "        \"Implement `lrate` above\"\n",
    "        if step is None:\n",
    "            step = self._step\n",
    "        return self.factor * \\\n",
    "            (self.model_size ** (-0.5) *\n",
    "            min(step ** (-0.5), step * self.warmup ** (-1.5)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam starts at lr=0; NoamOpt overwrites the learning rate on every step.\n",
    "base_optimizer = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)\n",
    "optimizer = NoamOpt(hid_dim, 1, 2000, base_optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-entropy loss; targets equal to pad_idx are excluded from the loss.\n",
    "criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, iterator, optimizer, criterion, clip):\n",
    "    \n",
    "    model.train()\n",
    "    \n",
    "    epoch_loss = 0\n",
    "    \n",
    "    for i, batch in enumerate(iterator):\n",
    "        \n",
    "        src = batch.src\n",
    "        trg = batch.trg\n",
    "        \n",
    "        optimizer.optimizer.zero_grad()\n",
    "        \n",
    "        output = model(src, trg[:,:-1])\n",
    "                \n",
    "        #output = [batch size, trg sent len - 1, output dim]\n",
    "        #trg = [batch size, trg sent len]\n",
    "            \n",
    "        output = output.contiguous().view(-1, output.shape[-1])\n",
    "        trg = trg[:,1:].contiguous().view(-1)\n",
    "                \n",
    "        #output = [batch size * trg sent len - 1, output dim]\n",
    "        #trg = [batch size * trg sent len - 1]\n",
    "            \n",
    "        loss = criterion(output, trg)\n",
    "        \n",
    "        loss.backward()\n",
    "        \n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)\n",
    "        \n",
    "        optimizer.step()\n",
    "        \n",
    "        epoch_loss += loss.item()\n",
    "        \n",
    "    return epoch_loss / len(iterator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, iterator, criterion):\n",
    "    \n",
    "    model.eval()\n",
    "    \n",
    "    epoch_loss = 0\n",
    "    \n",
    "    with torch.no_grad():\n",
    "    \n",
    "        for i, batch in enumerate(iterator):\n",
    "\n",
    "            src = batch.src\n",
    "            trg = batch.trg\n",
    "\n",
    "            output = model(src, trg[:,:-1])\n",
    "            \n",
    "            #output = [batch size, trg sent len - 1, output dim]\n",
    "            #trg = [batch size, trg sent len]\n",
    "            \n",
    "            output = output.contiguous().view(-1, output.shape[-1])\n",
    "            trg = trg[:,1:].contiguous().view(-1)\n",
    "            \n",
    "            #output = [batch size * trg sent len - 1, output dim]\n",
    "            #trg = [batch size * trg sent len - 1]\n",
    "            \n",
    "            loss = criterion(output, trg)\n",
    "\n",
    "            epoch_loss += loss.item()\n",
    "        \n",
    "    return epoch_loss / len(iterator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def epoch_time(start_time, end_time):\n",
    "    elapsed_time = end_time - start_time\n",
    "    elapsed_mins = int(elapsed_time / 60)\n",
    "    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
    "    return elapsed_mins, elapsed_secs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 01 | Time: 0m 51s\n",
      "\tTrain Loss: 5.908 | Train PPL: 367.799\n",
      "\t Val. Loss: 4.098 |  Val. PPL:  60.237\n",
      "Epoch: 02 | Time: 0m 51s\n",
      "\tTrain Loss: 3.779 | Train PPL:  43.785\n",
      "\t Val. Loss: 3.201 |  Val. PPL:  24.557\n",
      "Epoch: 03 | Time: 0m 51s\n",
      "\tTrain Loss: 3.137 | Train PPL:  23.041\n",
      "\t Val. Loss: 2.780 |  Val. PPL:  16.125\n",
      "Epoch: 04 | Time: 0m 51s\n",
      "\tTrain Loss: 2.765 | Train PPL:  15.873\n",
      "\t Val. Loss: 2.596 |  Val. PPL:  13.409\n",
      "Epoch: 05 | Time: 0m 50s\n",
      "\tTrain Loss: 2.499 | Train PPL:  12.174\n",
      "\t Val. Loss: 2.404 |  Val. PPL:  11.072\n",
      "Epoch: 06 | Time: 0m 51s\n",
      "\tTrain Loss: 2.311 | Train PPL:  10.089\n",
      "\t Val. Loss: 2.321 |  Val. PPL:  10.185\n",
      "Epoch: 07 | Time: 0m 51s\n",
      "\tTrain Loss: 2.184 | Train PPL:   8.880\n",
      "\t Val. Loss: 2.306 |  Val. PPL:  10.032\n",
      "Epoch: 08 | Time: 0m 50s\n",
      "\tTrain Loss: 2.103 | Train PPL:   8.194\n",
      "\t Val. Loss: 2.306 |  Val. PPL:  10.031\n",
      "Epoch: 09 | Time: 0m 51s\n",
      "\tTrain Loss: 2.066 | Train PPL:   7.892\n",
      "\t Val. Loss: 2.309 |  Val. PPL:  10.069\n",
      "Epoch: 10 | Time: 0m 50s\n",
      "\tTrain Loss: 2.007 | Train PPL:   7.441\n",
      "\t Val. Loss: 2.295 |  Val. PPL:   9.929\n"
     ]
    }
   ],
   "source": [
    "N_EPOCHS = 10\n",
    "CLIP = 1\n",
    "\n",
    "# Lowest validation loss seen so far; the model is checkpointed whenever it improves.\n",
    "best_valid_loss = float('inf')\n",
    "\n",
    "for epoch in range(N_EPOCHS):\n",
    "\n",
    "    tick = time.time()\n",
    "    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)\n",
    "    valid_loss = evaluate(model, valid_iterator, criterion)\n",
    "    tock = time.time()\n",
    "\n",
    "    epoch_mins, epoch_secs = epoch_time(tick, tock)\n",
    "\n",
    "    improved = valid_loss < best_valid_loss\n",
    "    if improved:\n",
    "        best_valid_loss = valid_loss\n",
    "        torch.save(model.state_dict(), 'tut6-model.pt')\n",
    "\n",
    "    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')\n",
    "    print(f'\\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')\n",
    "    print(f'\\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| Test Loss: 2.284 | Test PPL:   9.812 |\n"
     ]
    }
   ],
   "source": [
    "# Restore the checkpoint saved during training, then score it on the test split.\n",
    "best_state = torch.load('tut6-model.pt')\n",
    "model.load_state_dict(best_state)\n",
    "\n",
    "test_loss = evaluate(model, test_iterator, criterion)\n",
    "\n",
    "print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
